{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import sys\n",
    "if \"../\" not in sys.path:\n",
    "  sys.path.append(\"../\") \n",
    "from lib.utils import read_data_from_file, sign, sigmoid\n",
    "from lib.logistic_reg import logistic_reg, logistic_reg_sgd\n",
    "from sklearn.preprocessing import PolynomialFeatures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_err_rate(X, y, w):\n",
    "    \"\"\"Return the 0/1 error rate of the linear classifier `w` on `(X, y)`.\n",
    "\n",
    "    Predictions are `sign(X.dot(w))`; the result is the mean of the\n",
    "    element-wise comparison `prediction != y`, i.e. the fraction of\n",
    "    misclassified rows.\n",
    "    \"\"\"\n",
    "    scores = X.dot(w)\n",
    "    misclassified = sign(scores) != y\n",
    "    return misclassified.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data_train shape:  (1000, 21)\n",
      "data_test shape:  (3000, 21)\n",
      "[[0.57548   0.53938   0.72311   0.23702   0.95864   0.85208   0.68642\n",
      "  0.35431   0.0095435 0.93523   0.21656   0.62107   0.31371   0.82675\n",
      "  0.61655   0.17468   0.12009   0.38317   0.35514   0.39439   1.       ]\n",
      " [0.70727   0.88503   0.62762   0.93851   0.20865   0.82238   0.08001\n",
      "  0.22381   0.18949   0.57738   0.39569   0.89592   0.37106   0.71963\n",
      "  0.5582    0.067821  0.29071   0.39012   0.68854   0.077076  1.       ]\n",
      " [0.24443   0.27331   0.58282   0.92196   0.79214   0.62849   0.67165\n",
      "  0.88523   0.52063   0.81997   0.15601   0.62972   0.48671   0.6272\n",
      "  0.66056   0.79822   0.17185   0.52541   0.93579   0.69411   1.       ]]\n"
     ]
    }
   ],
   "source": [
    "# Load the train / test sets; the last column of each array is used as the label.\n",
    "data_train = read_data_from_file('hw3_train.dat')\n",
    "print('data_train shape: ', data_train.shape)\n",
    "data_test = read_data_from_file('hw3_test.dat')\n",
    "print('data_test shape: ', data_test.shape)\n",
    "\n",
    "# Design matrices: a leading column of ones (x0) followed by the raw features.\n",
    "X, y = np.hstack((np.ones((len(data_train), 1)), data_train[:, :-1])), data_train[:, -1]\n",
    "X_test, y_test = np.hstack((np.ones((len(data_test), 1)), data_test[:, :-1])), data_test[:, -1]\n",
    "\n",
    "# Peek at the first three raw training rows.\n",
    "print(data_train[:3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "18. For Questions 18-20, you will play with logistic regression. Please use the following set for training:<br/><br/>https://www.csie.ntu.edu.tw/~htlin/mooc/datasets/mlfound_algo/hw3_train.dat<br/><br/>and the following set for testing:<br/><br/>\n",
    "https://www.csie.ntu.edu.tw/~htlin/mooc/datasets/mlfound_algo/hw3_test.dat<br/><br/>Implement the fixed learning rate gradient descent algorithm for logistic regression. Run the algorithm with $\\eta=0.001$ and $T = 2000$. What is $E_{out}(g)$ from your algorithm, evaluated using the 0/1 error on the test set?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "episode: 2000"
     ]
    }
   ],
   "source": [
    "# Q18 setting: eta = 0.001, 2000 updates. The first return value is kept as the\n",
    "# weight vector; the second return value of logistic_reg is discarded.\n",
    "w_reg, _ = logistic_reg(X, y, eta=0.001, updates=2000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ein:   0.466\n",
      "Eout:  0.475\n"
     ]
    }
   ],
   "source": [
    "# 0/1 error of w_reg on the training set (Ein) and on the test set (Eout).\n",
    "print('Ein:  ', get_err_rate(X, y, w_reg))\n",
    "print('Eout: ', get_err_rate(X_test, y_test, w_reg))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 结果表现出来的是高偏差"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "19. Implement the fixed learning rate gradient descent algorithm for logistic regression. Run the algorithm with $\\eta=0.01$ and $T = 2000$, what is $E_{out}(g)$ from your algorithm, evaluated using the 0/1 error on the test set?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "episode: 2000"
     ]
    }
   ],
   "source": [
    "# Q19 setting: eta = 0.01, 2000 updates. This overwrites the Q18 weights in w_reg.\n",
    "w_reg, _ = logistic_reg(X, y, eta=0.01, updates=2000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ein:   0.197\n",
      "Eout:  0.22\n"
     ]
    }
   ],
   "source": [
    "# 0/1 error of the current w_reg on the training set (Ein) and on the test set (Eout).\n",
    "print('Ein:  ', get_err_rate(X, y, w_reg))\n",
    "print('Eout: ', get_err_rate(X_test, y_test, w_reg))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 偏差仍旧很高"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "episode: 50000\n",
      "Ein:   0.172\n",
      "Eout:  0.18166666666666667\n"
     ]
    }
   ],
   "source": [
    "# 调参后的结果（应多尝试其他数值）\n",
    "w_reg, _ = logistic_reg(X, y, eta=0.005, updates=50000)\n",
    "print()\n",
    "print('Ein:  ', get_err_rate(X, y, w_reg))\n",
    "print('Eout: ', get_err_rate(X_test, y_test, w_reg))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 要考虑增加模型复杂度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data_train shape:  (1000, 21)\n",
      "data_test shape:  (3000, 21)\n",
      "X shape:  (1000, 231)\n",
      "X_test shape:  (3000, 231)\n",
      "[[1.00000000e+00 5.75480000e-01 5.39380000e-01 7.23110000e-01\n",
      "  2.37020000e-01 9.58640000e-01 8.52080000e-01 6.86420000e-01\n",
      "  3.54310000e-01 9.54350000e-03 9.35230000e-01 2.16560000e-01\n",
      "  6.21070000e-01 3.13710000e-01 8.26750000e-01 6.16550000e-01\n",
      "  1.74680000e-01 1.20090000e-01 3.83170000e-01 3.55140000e-01\n",
      "  3.94390000e-01 3.31177230e-01 3.10402402e-01 4.16135343e-01\n",
      "  1.36400270e-01 5.51678147e-01 4.90354998e-01 3.95020982e-01\n",
      "  2.03898319e-01 5.49209338e-03 5.38206160e-01 1.24625949e-01\n",
      "  3.57413364e-01 1.80533831e-01 4.75778090e-01 3.54812194e-01\n",
      "  1.00524846e-01 6.91093932e-02 2.20506672e-01 2.04375967e-01\n",
      "  2.26963557e-01 2.90930784e-01 3.90031072e-01 1.27843848e-01\n",
      "  5.17071243e-01 4.59594910e-01 3.70241220e-01 1.91107728e-01\n",
      "  5.14757303e-03 5.04444357e-01 1.16808133e-01 3.34992737e-01\n",
      "  1.69208900e-01 4.45932415e-01 3.32554739e-01 9.42188984e-02\n",
      "  6.47741442e-02 2.06674235e-01 1.91555413e-01 2.12726078e-01\n",
      "  5.22888072e-01 1.71391532e-01 6.93202170e-01 6.16147569e-01\n",
      "  4.96357166e-01 2.56205104e-01 6.90100029e-03 6.76274165e-01\n",
      "  1.56596702e-01 4.49101928e-01 2.26846838e-01 5.97831193e-01\n",
      "  4.45833471e-01 1.26312855e-01 8.68382799e-02 2.77074059e-01\n",
      "  2.56805285e-01 2.85187353e-01 5.61784804e-02 2.27216853e-01\n",
      "  2.01960002e-01 1.62695268e-01 8.39785562e-02 2.26200037e-03\n",
      "  2.21668215e-01 5.13290512e-02 1.47206011e-01 7.43555442e-02\n",
      "  1.95956285e-01 1.46134681e-01 4.14026536e-02 2.84637318e-02\n",
      "  9.08189534e-02 8.41752828e-02 9.34783178e-02 9.18990650e-01\n",
      "  8.16837971e-01 6.58029669e-01 3.39655738e-01 9.14878084e-03\n",
      "  8.96548887e-01 2.07603078e-01 5.95382545e-01 3.00734954e-01\n",
      "  7.92555620e-01 5.91049492e-01 1.67455235e-01 1.15123078e-01\n",
      "  3.67322089e-01 3.40451410e-01 3.78078030e-01 7.26040326e-01\n",
      "  5.84884754e-01 3.01900465e-01 8.13182548e-03 7.96890778e-01\n",
      "  1.84526445e-01 5.29201326e-01 2.67306017e-01 7.04457140e-01\n",
      "  5.25349924e-01 1.48841334e-01 1.02326287e-01 3.26491494e-01\n",
      "  3.02607691e-01 3.36051831e-01 4.71172416e-01 2.43205470e-01\n",
      "  6.55084927e-03 6.41960577e-01 1.48651115e-01 4.26314869e-01\n",
      "  2.15336818e-01 5.67497735e-01 4.23212251e-01 1.19903846e-01\n",
      "  8.24321778e-02 2.63015551e-01 2.43775199e-01 2.70717184e-01\n",
      "  1.25535576e-01 3.38135748e-03 3.31361341e-01 7.67293736e-02\n",
      "  2.20051312e-01 1.11150590e-01 2.92925793e-01 2.18449831e-01\n",
      "  6.18908708e-02 4.25490879e-02 1.35760963e-01 1.25829653e-01\n",
      "  1.39736321e-01 9.10783923e-05 8.92536751e-03 2.06674036e-03\n",
      "  5.92718155e-03 2.99389139e-03 7.89008862e-03 5.88404492e-03\n",
      "  1.66705858e-03 1.14607892e-03 3.65678290e-03 3.38927859e-03\n",
      "  3.76386096e-03 8.74655153e-01 2.02533409e-01 5.80843296e-01\n",
      "  2.93391003e-01 7.73201403e-01 5.76616057e-01 1.63365976e-01\n",
      "  1.12311771e-01 3.58352079e-01 3.32137582e-01 3.68845360e-01\n",
      "  4.68982336e-02 1.34498919e-01 6.79370376e-02 1.79040980e-01\n",
      "  1.33520068e-01 3.78287008e-02 2.60066904e-02 8.29792952e-02\n",
      "  7.69091184e-02 8.54090984e-02 3.85727945e-01 1.94835870e-01\n",
      "  5.13469622e-01 3.82920709e-01 1.08488508e-01 7.45842963e-02\n",
      "  2.37975392e-01 2.20566800e-01 2.44943797e-01 9.84139641e-02\n",
      "  2.59359742e-01 1.93417901e-01 5.47988628e-02 3.76734339e-02\n",
      "  1.20204261e-01 1.11410969e-01 1.23724087e-01 6.83515563e-01\n",
      "  5.09732713e-01 1.44416690e-01 9.92844075e-02 3.16785798e-01\n",
      "  2.93611995e-01 3.26061932e-01 3.80133903e-01 1.07698954e-01\n",
      "  7.40414895e-02 2.36243464e-01 2.18961567e-01 2.43161155e-01\n",
      "  3.05131024e-02 2.09773212e-02 6.69321356e-02 6.20358552e-02\n",
      "  6.88920452e-02 1.44216081e-02 4.60148853e-02 4.26487626e-02\n",
      "  4.73622951e-02 1.46819249e-01 1.36078994e-01 1.51118416e-01\n",
      "  1.26124420e-01 1.40063665e-01 1.55543472e-01]\n",
      " [1.00000000e+00 7.07270000e-01 8.85030000e-01 6.27620000e-01\n",
      "  9.38510000e-01 2.08650000e-01 8.22380000e-01 8.00100000e-02\n",
      "  2.23810000e-01 1.89490000e-01 5.77380000e-01 3.95690000e-01\n",
      "  8.95920000e-01 3.71060000e-01 7.19630000e-01 5.58200000e-01\n",
      "  6.78210000e-02 2.90710000e-01 3.90120000e-01 6.88540000e-01\n",
      "  7.70760000e-02 5.00230853e-01 6.25955168e-01 4.43896797e-01\n",
      "  6.63779968e-01 1.47571885e-01 5.81644703e-01 5.65886727e-02\n",
      "  1.58294099e-01 1.34020592e-01 4.08363553e-01 2.79859666e-01\n",
      "  6.33657338e-01 2.62439606e-01 5.08972710e-01 3.94798114e-01\n",
      "  4.79677587e-02 2.05610462e-01 2.75920172e-01 4.86983686e-01\n",
      "  5.45135425e-02 7.83278101e-01 5.55462529e-01 8.30609505e-01\n",
      "  1.84661509e-01 7.27830971e-01 7.08112503e-02 1.98078564e-01\n",
      "  1.67704335e-01 5.10998621e-01 3.50197521e-01 7.92916078e-01\n",
      "  3.28399232e-01 6.36894139e-01 4.94023746e-01 6.00236196e-02\n",
      "  2.57287071e-01 3.45267904e-01 6.09378556e-01 6.82145723e-02\n",
      "  3.93906864e-01 5.89027646e-01 1.30952913e-01 5.16142136e-01\n",
      "  5.02158762e-02 1.40467632e-01 1.18927714e-01 3.62375236e-01\n",
      "  2.48342958e-01 5.62297310e-01 2.32884677e-01 4.51654181e-01\n",
      "  3.50337484e-01 4.25658160e-02 1.82455410e-01 2.44847114e-01\n",
      "  4.32141475e-01 4.83744391e-02 8.80801020e-01 1.95820112e-01\n",
      "  7.71811854e-01 7.50901851e-02 2.10047923e-01 1.77838260e-01\n",
      "  5.41876904e-01 3.71359022e-01 8.40829879e-01 3.48243521e-01\n",
      "  6.75379951e-01 5.23876282e-01 6.36506867e-02 2.72834242e-01\n",
      "  3.66131521e-01 6.46201675e-01 7.23365968e-02 4.35348225e-02\n",
      "  1.71589587e-01 1.66940865e-02 4.66979565e-02 3.95370885e-02\n",
      "  1.20470337e-01 8.25607185e-02 1.86933708e-01 7.74216690e-02\n",
      "  1.50150799e-01 1.16468430e-01 1.41508517e-02 6.06566415e-02\n",
      "  8.13985380e-02 1.43663871e-01 1.60819074e-02 6.76308864e-01\n",
      "  6.57986238e-02 1.84056868e-01 1.55832786e-01 4.74825764e-01\n",
      "  3.25407542e-01 7.36786690e-01 3.05152323e-01 5.91809319e-01\n",
      "  4.59052516e-01 5.57746340e-02 2.39074090e-01 3.20826886e-01\n",
      "  5.66241525e-01 6.33857609e-02 6.40160010e-03 1.79070381e-02\n",
      "  1.51610949e-02 4.61961738e-02 3.16591569e-02 7.16825592e-02\n",
      "  2.96885106e-02 5.75775963e-02 4.46615820e-02 5.42635821e-03\n",
      "  2.32597071e-02 3.12135012e-02 5.50900854e-02 6.16685076e-03\n",
      "  5.00909161e-02 4.24097569e-02 1.29223418e-01 8.85593789e-02\n",
      "  2.00515855e-01 8.30469386e-02 1.61060390e-01 1.24930742e-01\n",
      "  1.51790180e-02 6.50638051e-02 8.73127572e-02 1.54102137e-01\n",
      "  1.72503796e-02 3.59064601e-02 1.09407736e-01 7.49792981e-02\n",
      "  1.69767881e-01 7.03121594e-02 1.36362689e-01 1.05773318e-01\n",
      "  1.28514013e-02 5.50866379e-02 7.39238388e-02 1.30471445e-01\n",
      "  1.46051312e-02 3.33367664e-01 2.28463492e-01 5.17286290e-01\n",
      "  2.14242623e-01 4.15499969e-01 3.22293516e-01 3.91584890e-02\n",
      "  1.67850140e-01 2.25247486e-01 3.97549225e-01 4.45021409e-02\n",
      "  1.56570576e-01 3.54506585e-01 1.46824731e-01 2.84750395e-01\n",
      "  2.20874158e-01 2.68360915e-02 1.15031040e-01 1.54366583e-01\n",
      "  2.72448393e-01 3.04982024e-02 8.02672646e-01 3.32440075e-01\n",
      "  6.44730910e-01 5.00102544e-01 6.07621903e-02 2.60452903e-01\n",
      "  3.49516310e-01 6.16876757e-01 6.90539299e-02 1.37685524e-01\n",
      "  2.67025908e-01 2.07125692e-01 2.51656603e-02 1.07870853e-01\n",
      "  1.44757927e-01 2.55489652e-01 2.85998206e-02 5.17867337e-01\n",
      "  4.01697466e-01 4.88060262e-02 2.09203637e-01 2.80742056e-01\n",
      "  4.95494040e-01 5.54662019e-02 3.11587240e-01 3.78576822e-02\n",
      "  1.62274322e-01 2.17764984e-01 3.84343028e-01 4.30238232e-02\n",
      "  4.59968804e-03 1.97162429e-02 2.64583285e-02 4.66974713e-02\n",
      "  5.22737140e-03 8.45123041e-02 1.13411785e-01 2.00165463e-01\n",
      "  2.24067640e-02 1.52193614e-01 2.68613225e-01 3.00688891e-02\n",
      "  4.74087332e-01 5.30699090e-02 5.94070978e-03]]\n"
     ]
    }
   ],
   "source": [
    "# Reload the data and split features / labels WITHOUT adding the x0 column:\n",
    "# PolynomialFeatures(2) emits its own constant column as part of the expansion.\n",
    "data_train = read_data_from_file('hw3_train.dat')\n",
    "print('data_train shape: ', data_train.shape)\n",
    "data_test = read_data_from_file('hw3_test.dat')\n",
    "print('data_test shape: ', data_test.shape)\n",
    "y = data_train[:,-1]\n",
    "X = data_train[:,:-1]\n",
    "y_test = data_test[:,-1]\n",
    "X_test = data_test[:,:-1]\n",
    "\n",
    "# Degree-2 polynomial expansion. The transformer is fitted on the training set\n",
    "# only and then reused for the test set, instead of being re-fitted on test data.\n",
    "# NOTE: this overwrites X / X_test, so every later cell sees the expanded\n",
    "# features unless it rebuilds the original design matrices.\n",
    "poly = PolynomialFeatures(2)\n",
    "X = poly.fit_transform(X)\n",
    "X_test = poly.transform(X_test)\n",
    "print('X shape: ', X.shape)\n",
    "print('X_test shape: ', X_test.shape)\n",
    "print(X[:2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "episode: 20000\n",
      "Ein:   0.151\n",
      "Eout:  0.18933333333333333\n"
     ]
    }
   ],
   "source": [
    "# Gradient descent on whatever X currently holds (the degree-2 polynomial\n",
    "# features when the notebook is run top to bottom).\n",
    "w_reg, _ = logistic_reg(X, y, eta=0.008, updates=20000)\n",
    "# Bare print(): presumably ends the 'episode: N' progress line from logistic_reg.\n",
    "print()\n",
    "print('Ein:  ', get_err_rate(X, y, w_reg))\n",
    "print('Eout: ', get_err_rate(X_test, y_test, w_reg))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 训练误差 $E_{in}$ 有所下降，但测试误差 $E_{out}$ 并未改善，出现较高方差，有 overfitting，应考虑加大训练的数据量，或换模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "20. Implement the fixed learning rate stochastic gradient descent algorithm for logistic regression. Instead of randomly choosing $n$ in each iteration, please simply pick the example with the cyclic order $n = 1, 2, ..., N, 1, 2, ...$<br/><br/>Run the algorithm with $\\eta=0.001$ and $T = 2000$. What is $E_{out}(g)$ from your algorithm, evaluated using the 0/1 error on the test set?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "episode: 2000\t"
     ]
    }
   ],
   "source": [
    "# Q20 is posed on the original features, but the polynomial-feature cell above\n",
    "# overwrote X / X_test with the degree-2 expansion. Rebuild the x0-augmented\n",
    "# design matrices here so Q20 is evaluated on the same inputs as Q18 / Q19\n",
    "# when the notebook is run top to bottom.\n",
    "X = np.concatenate((np.ones((data_train.shape[0], 1)), data_train[:, :-1]), axis=1)\n",
    "X_test = np.concatenate((np.ones((data_test.shape[0], 1)), data_test[:, :-1]), axis=1)\n",
    "\n",
    "# Q20 setting: eta = 0.001, 2000 updates, random_choice=False as the question\n",
    "# asks for cyclic example order. The second return value is discarded.\n",
    "w_reg_sgd, _ = logistic_reg_sgd(X, y, eta=0.001, updates=2000, random_choice=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eout:  0.23533333333333334\n"
     ]
    }
   ],
   "source": [
    "# 0/1 error of the SGD weights on the test set (Eout).\n",
    "print('Eout: ', get_err_rate(X_test, y_test, w_reg_sgd))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "episode: 100000\t\n",
      "Eout:  0.18333333333333332\n"
     ]
    }
   ],
   "source": [
    "# 调参后的结果（应多尝试其他数值）\n",
    "w_reg, _ = logistic_reg_sgd(X, y, eta=0.0005, updates=100000)\n",
    "print()\n",
    "print('Eout: ', get_err_rate(X_test, y_test, w_reg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
